import os
import blechpy

PROJ_DIR = '/data/Katz_Data/Stk11_Project/'

def get_file_dirs(animID):
    """Find the recording directories belonging to one animal.

    A recording directory is any immediate sub-directory of
    ``PROJ_DIR/<animID>`` that directly contains at least one file whose
    name ends in ``.dat``.

    Parameters
    ----------
    animID : str
        Name of the animal's folder inside ``PROJ_DIR``.

    Returns
    -------
    tuple of (list of str, str)
        The matching recording directory paths (in ``os.listdir`` order,
        which is arbitrary) and the animal's own directory path.
    """
    anim_dir = os.path.join(PROJ_DIR, animID)
    recording_dirs = []
    for entry in os.listdir(anim_dir):
        candidate = os.path.join(anim_dir, entry)
        if not os.path.isdir(candidate):
            continue  # plain files at the animal level are ignored
        if any(name.endswith('.dat') for name in os.listdir(candidate)):
            recording_dirs.append(candidate)

    return recording_dirs, anim_dir


def _preprocess_recording(fd, dead_channels):
    """Extract, clean, re-reference and save one previously saved dataset.

    Parameters
    ----------
    fd : str
        Recording directory containing a dataset loadable by
        ``blechpy.load_dataset``.
    dead_channels : list of int
        Channel numbers passed to ``mark_dead_channels``.
    """
    # First-time setup (currently disabled). A brand-new recording must be
    # turned into a dataset before load_dataset() can be used, e.g.:
    #   data_name = '_'.join(os.path.basename(fd).split('_')[:-2])
    #   din = (['Water', 'Quinine', 'NaCl', 'Citric Acid']
    #          if '4taste' in data_name else ['Saccharin'])
    #   dat = blechpy.dataset(file_dir=fd, data_name=data_name, shell=True)
    #   dat.initParams(data_quality='clean', emg_port=False,
    #                  emg_channels=None, car_keyword='bilateral32',
    #                  car_group_areas=['GC', 'GC'],
    #                  shell=True, dig_in_names=din,
    #                  dig_out_names=None, accept_params=True)
    #   dat.create_trial_list()
    dat = blechpy.load_dataset(fd)
    dat.extract_data()
    dat.mark_dead_channels(dead_channels)
    dat.common_average_reference()
    dat.save()


def analyze_data(animals=None, dead_ch=None):
    """Preprocess and spike-sort every recording of the given animals.

    Work runs in two passes so that all preprocessing finishes before any
    clustering starts:

    1. For each recording directory (see ``get_file_dirs``): load the saved
       dataset, extract data, mark the animal's dead channels, apply common
       average referencing, and save.
    2. For each recording directory: reload the dataset, detect spikes and
       run ``blech_clust_run(umap=True)``.

    Parameters
    ----------
    animals : iterable of str, optional
        Animal IDs (folders inside ``PROJ_DIR``) to process. Defaults to
        ``['RN30']``.
    dead_ch : dict of str -> list of int, optional
        Dead channel numbers per animal. Defaults to the built-in map for
        RN28-RN31.

    Raises
    ------
    KeyError
        If any requested animal has no entry in ``dead_ch``. This is checked
        before any processing starts, so no recording is half-processed.
    """
    if animals is None:
        # Other cohorts processed previously: ['RN28', 'RN29'], ['RN30', 'RN31']
        animals = ['RN30']
    # Materialise: the list is iterated twice (once per pass), which would
    # silently skip pass 2 if a generator were passed in.
    animals = list(animals)
    if dead_ch is None:
        dead_ch = {'RN28': [7, 9, 10, 20, 21, 22], 'RN29': [],
                   'RN30': [8, 12, 9, 10, 20, 21, 22], 'RN31': [7, 8, 12]}

    missing = [anim for anim in animals if anim not in dead_ch]
    if missing:
        raise KeyError('No dead-channel list for animal(s): %s' % missing)

    # Pass 1: preprocess every recording.
    for anim in animals:
        file_dirs, _ = get_file_dirs(anim)
        for fd in file_dirs:
            _preprocess_recording(fd, dead_ch[anim])

    # Pass 2: spike detection and clustering on the preprocessed data.
    for anim in animals:
        file_dirs, _ = get_file_dirs(anim)
        for fd in file_dirs:
            dat = blechpy.load_dataset(fd)
            dat.detect_spikes()
            dat.blech_clust_run(umap=True)

